"""Test demonstration storing system."""
import pytest
from cloudpathlib import implementation_registry
from cloudpathlib.local import LocalGSClient, local_gs_implementation
from pathlib import Path
from typing import Optional

from bigym.action_modes import JointPositionActionMode, TorqueActionMode, ActionMode
from bigym.bigym_env import BiGymEnv
from bigym.envs.reach_target import ReachTarget
from bigym.envs.move_plates import MovePlate
from bigym.envs.mainpulation import StackBlocks
from bigym.utils.observation_config import CameraConfig, ObservationConfig
from demonstrations.demo import Demo, LightweightDemo
from demonstrations.demo_recorder import DemoRecorder
from demonstrations.demo_store import (
    DemoStore,
    DemoNotFoundError,
    TooManyDemosRequestedError,
)
from demonstrations.utils import ObservationMode, Metadata
from demonstrations.const import SAFETENSORS_SUFFIX

from tests.test_demos import assert_timesteps_equal

# Environment classes and action-mode classes exercised by the multi-demo
# round-trip tests, plus how many demos to record per configuration.
ENV_CLASSES = [
    ReachTarget,
    MovePlate,
    StackBlocks,
]
ACTION_MODES = [
    JointPositionActionMode,
    TorqueActionMode,
]

NUM_DEMOS_PER_ENV = 2

# Default observation config (used for the non-pixel observation modes) and a
# pixel config with a single low-resolution 4x4 RGB head camera (no depth).
STATE_CONFIG = ObservationConfig()
_HEAD_CAMERA = CameraConfig(name="head", rgb=True, depth=False, resolution=(4, 4))
PIXEL_CONFIG = ObservationConfig(cameras=[_HEAD_CAMERA])


@pytest.fixture
def mock_gs_bucket(monkeypatch):
    """Point ``gs://`` paths at cloudpathlib's local mock for a single test.

    The ``"gs"`` entry of cloudpathlib's ``implementation_registry`` is replaced
    with ``local_gs_implementation`` through ``monkeypatch``, which restores the
    original entry once the test is finished. After the test, the root directory
    of ``DemoStore.google_cloud()`` is removed and the local GS client's default
    storage directory is reset.
    """
    monkeypatch.setitem(implementation_registry, "gs", local_gs_implementation)
    yield
    # This fixture depends on ``monkeypatch``, so pytest tears it down first.
    # The local mock is therefore still registered while we clean up here.
    try:
        demo_store = DemoStore.google_cloud()
        demo_store._path.rmtree()
    finally:
        # Reset even if the removal above failed. Otherwise later tests would
        # share this test's local storage directory.
        LocalGSClient.reset_default_storage_dir()


class TestDemoStore:
    """Round-trip tests for ``DemoStore`` against a locally mocked GCS bucket."""

    @staticmethod
    def _observation_config_for(obs_mode: ObservationMode) -> ObservationConfig:
        """Return the observation config these tests use for ``obs_mode``.

        Pixel mode uses ``PIXEL_CONFIG``. Every other mode uses ``STATE_CONFIG``.
        """
        return PIXEL_CONFIG if obs_mode == ObservationMode.Pixel else STATE_CONFIG

    @staticmethod
    def _action_mode_absolute_for(action_mode: type[ActionMode]) -> Optional[bool]:
        """Return the ``action_mode_absolute`` value to query the store with.

        ``TorqueActionMode`` maps to ``None``. Every other action mode maps to
        ``False``, matching the argument-less ``action_mode()`` instances built by
        ``generate_demos``. This presumably means those defaults are non-absolute;
        confirm against the action-mode implementations.
        """
        return None if action_mode == TorqueActionMode else False

    @staticmethod
    def record_demo(
        env: BiGymEnv,
        recorder: DemoRecorder,
        length: int = 2,
        seed: Optional[int] = None,
        is_lightweight: bool = False,
    ):
        """Reset ``env`` with ``seed`` and record ``length`` random-action steps.

        The finished demo is left on ``recorder.demo``.
        """
        env.reset(seed=seed)
        recorder.record(env, lightweight_demo=is_lightweight)
        for _ in range(length):
            action = env.action_space.sample()
            timestep = env.step(action)
            recorder.add_timestep(timestep, action)
        recorder.stop()

    @staticmethod
    def generate_demos(
        env_classes: list[type[BiGymEnv]] = ENV_CLASSES,
        action_modes: list[type[ActionMode]] = ACTION_MODES,
        obs_modes: list[ObservationMode] = list(ObservationMode),
        num_demos_per_env: int = NUM_DEMOS_PER_ENV,
    ) -> dict[int, Demo]:
        """Record demos for every (env, action mode, obs mode) combination.

        Returns:
            A mapping from the seed each demo was recorded with to the demo.
            Seeds are unique across all combinations, so a demo read back from
            the store can be matched to its original via ``metadata.seed``.
        """
        demos = {}
        recorder = DemoRecorder()
        seed = 0
        for env_class in env_classes:
            for action_mode in action_modes:
                for obs_mode in obs_modes:
                    observation_config = TestDemoStore._observation_config_for(
                        obs_mode
                    )
                    is_lightweight = obs_mode == ObservationMode.Lightweight
                    for _ in range(num_demos_per_env):
                        env: BiGymEnv = env_class(
                            action_mode=action_mode(),
                            observation_config=observation_config,
                        )
                        TestDemoStore.record_demo(
                            env, recorder, seed=seed, is_lightweight=is_lightweight
                        )
                        demo = recorder.demo
                        demos[seed] = demo
                        seed += 1
                    # The extra increments below only leave gaps between
                    # configurations. All that matters is that seeds stay unique.
                    seed += 1
                seed += 1
            seed += 1
        return demos

    def test_upload_and_download_of_a_demo(self, mock_gs_bucket):
        recorder = DemoRecorder()
        env: BiGymEnv = ReachTarget(action_mode=JointPositionActionMode())
        TestDemoStore.record_demo(env, recorder, seed=42)
        demo_to_store = recorder.demo

        metadata = Metadata.from_env(env)

        demo_store = DemoStore.google_cloud()
        demo_store.upload_demo(demo_to_store)
        demos_from_store = demo_store.get_demos(metadata, always_decimate=False)
        assert len(demos_from_store) == 1
        demo_from_store = demos_from_store[0]
        assert_timesteps_equal(demo_to_store, demo_from_store)

    @staticmethod
    def _test_upload_and_download_multiple_demos(
        env_classes: list[type[BiGymEnv]] = ENV_CLASSES,
        action_modes: list[type[ActionMode]] = ACTION_MODES,
        obs_modes: list[ObservationMode] = list(ObservationMode),
        num_demos_per_env: int = NUM_DEMOS_PER_ENV,
    ):
        """Upload generated demos, then check every configuration reads back.

        For each (env, action mode, obs mode) combination this checks the number
        of demos returned by the store. It also checks that *each* returned demo
        matches the demo originally recorded with the same seed.
        """
        demos_to_store = TestDemoStore.generate_demos(
            env_classes=env_classes,
            action_modes=action_modes,
            obs_modes=obs_modes,
            num_demos_per_env=num_demos_per_env,
        )

        demo_store = DemoStore.google_cloud()
        demo_store.upload_demos(list(demos_to_store.values()))

        for env_class in env_classes:
            for action_mode in action_modes:
                for obs_mode in obs_modes:
                    metadata = Metadata.for_demo_store(
                        env_class,
                        action_mode,
                        obs_mode=obs_mode,
                        observation_config=TestDemoStore._observation_config_for(
                            obs_mode
                        ),
                        action_mode_absolute=TestDemoStore._action_mode_absolute_for(
                            action_mode
                        ),
                    )
                    demos_from_store = demo_store.get_demos(
                        metadata, always_decimate=False
                    )
                    if obs_mode == ObservationMode.Lightweight:
                        # Every upload also stores a lightweight copy (see
                        # test_implicit_saving_of_lightweight_demos). The
                        # lightweight query therefore sees one demo per uploaded
                        # demo across all observation modes.
                        assert len(demos_from_store) == num_demos_per_env * len(
                            obs_modes
                        )
                    else:
                        assert len(demos_from_store) == num_demos_per_env
                    for demo_from_store in demos_from_store:
                        demo_to_store = demos_to_store[demo_from_store.metadata.seed]
                        if obs_mode == ObservationMode.Lightweight:
                            demo_to_store = LightweightDemo.from_demo(demo_to_store)
                        # Compare every retrieved demo, not just the last one.
                        assert_timesteps_equal(demo_from_store, demo_to_store)

    def test_upload_and_download_of_demos(self, mock_gs_bucket):
        self._test_upload_and_download_multiple_demos(
            env_classes=[ReachTarget],
            action_modes=[JointPositionActionMode],
            obs_modes=[ObservationMode.Lightweight],
            num_demos_per_env=2,
        )

    def test_upload_and_download_of_demos_with_multiple_envs(self, mock_gs_bucket):
        self._test_upload_and_download_multiple_demos(
            env_classes=ENV_CLASSES,
            action_modes=[JointPositionActionMode],
            obs_modes=[ObservationMode.Lightweight],
            num_demos_per_env=1,
        )

    def test_upload_and_download_of_demos_with_multiple_action_modes(
        self, mock_gs_bucket
    ):
        self._test_upload_and_download_multiple_demos(
            env_classes=[ReachTarget],
            action_modes=ACTION_MODES,
            obs_modes=[ObservationMode.Lightweight],
            num_demos_per_env=1,
        )

    def test_upload_and_download_of_demos_with_multiple_obs_modes(self, mock_gs_bucket):
        self._test_upload_and_download_multiple_demos(
            env_classes=[ReachTarget],
            action_modes=[JointPositionActionMode],
            obs_modes=list(ObservationMode),
            num_demos_per_env=1,
        )

    def test_correct_file_structure(self, mock_gs_bucket):
        # Load the demos from the test_data folder and upload them to the cloud
        data_dir = Path(__file__).parent / "data/safetensors"
        demo_store = DemoStore.google_cloud()
        demo_store.upload_safetensors(list(data_dir.rglob(f"*{SAFETENSORS_SUFFIX}")))

        for env_class in ENV_CLASSES:
            for action_mode in ACTION_MODES:
                for obs_mode in list(ObservationMode):
                    metadata = Metadata.for_demo_store(
                        env_class,
                        action_mode,
                        obs_mode=obs_mode,
                        observation_config=self._observation_config_for(obs_mode),
                        action_mode_absolute=self._action_mode_absolute_for(
                            action_mode
                        ),
                        floating_dofs=[],
                    )
                    demo_paths = demo_store.list_demo_paths(metadata)
                    assert len(demo_paths) > 0
                    demo_path = demo_paths[0]
                    # Layout: <root>/<env>/<action mode>/<obs mode>[/<cameras>]
                    expected_path = (
                        demo_store._path
                        / metadata.env_name
                        / metadata.get_action_mode_description()
                        / obs_mode.value
                    )
                    if obs_mode == ObservationMode.Pixel:
                        expected_path /= metadata.get_camera_description()
                    assert demo_path.parent == expected_path

    def test_get_demo_with_new_observations(self, mock_gs_bucket):
        env: BiGymEnv = ReachTarget(
            action_mode=JointPositionActionMode(absolute=True),
            observation_config=ObservationConfig(
                cameras=[
                    CameraConfig(name=name, resolution=(4, 4))
                    for name in ["head", "right_wrist", "left_wrist"]
                ],
            ),
        )
        env.reset()
        original_metadata = Metadata.from_env(env)
        heavy_recorder = DemoRecorder()
        lightweight_recorder = DemoRecorder()
        heavy_recorder.record(env)
        lightweight_recorder.record(env, lightweight_demo=True)
        for _ in range(5):
            action = env.action_space.sample()
            timestep = env.step(action)
            heavy_recorder.add_timestep(timestep, action)
            lightweight_recorder.add_timestep(timestep, action)
        heavy_recorder.stop()
        lightweight_recorder.stop()

        heavy_demo = heavy_recorder.demo
        lightweight_demo = lightweight_recorder.demo
        for timestep in lightweight_demo.timesteps:
            assert timestep.observation == {}
            assert timestep.reward is None

        # Only the lightweight demo is uploaded. The store must rebuild the
        # camera observations requested by the original metadata.
        demo_store = DemoStore.google_cloud()
        demo_store.upload_demo(lightweight_demo)
        demos_from_store = demo_store.get_demos(
            original_metadata, always_decimate=False
        )
        assert demo_store.lightweight_demo_exists(lightweight_demo.metadata)
        assert len(demos_from_store) == 1
        recreated_demo = demos_from_store[0]

        for timestep in recreated_demo.timesteps:
            obs = timestep.observation
            assert "rgb_head" in dict(obs)
            assert "rgb_right_wrist" in dict(obs)
            assert "rgb_left_wrist" in dict(obs)
            assert obs["rgb_head"].shape == (3, 4, 4)
            assert obs["rgb_right_wrist"].shape == (3, 4, 4)
            assert obs["rgb_left_wrist"].shape == (3, 4, 4)

        assert_timesteps_equal(heavy_demo, recreated_demo)

    def test_retrieve_n_demos(self, mock_gs_bucket):
        env: BiGymEnv = ReachTarget(
            action_mode=JointPositionActionMode(absolute=True),
            observation_config=ObservationConfig(
                cameras=[
                    CameraConfig(name=name, resolution=(4, 4))
                    for name in ["head", "right_wrist", "left_wrist"]
                ],
            ),
        )
        metadata = Metadata.from_env(env)
        recorder = DemoRecorder()
        demos = []
        for _ in range(10):
            env.reset()
            recorder.record(env)
            action = env.action_space.sample()
            timestep = env.step(action)
            recorder.add_timestep(timestep, action)
            recorder.stop()
            demos.append(recorder.demo)

        demo_store = DemoStore.google_cloud()
        demo_store.upload_demos(demos)

        for amount in range(0, 11, 2):
            demos_from_store = demo_store.get_demos(
                metadata, amount=amount, always_decimate=False
            )
            assert len(demos_from_store) == amount

    def test_implicit_saving_of_lightweight_demos(self, mock_gs_bucket):
        demo_to_store = _generate_simple_demo()
        demo_store = DemoStore.google_cloud()
        demo_store.upload_demo(demo_to_store)
        assert demo_store.lightweight_demo_exists(demo_to_store.metadata)

    def test_exception_thrown_if_demos_do_not_exist(self, mock_gs_bucket):
        demo_store = DemoStore.google_cloud()
        metadata = Metadata.for_demo_store(
            ReachTarget,
            JointPositionActionMode,
            obs_mode=ObservationMode.Lightweight,
        )
        with pytest.raises(DemoNotFoundError):
            demo_store.get_demos(metadata, always_decimate=False)

    def test_exception_thrown_if_lightweight_demos_do_not_exist(self, mock_gs_bucket):
        demo_store = DemoStore.google_cloud()
        metadata = Metadata.for_demo_store(
            ReachTarget,
            JointPositionActionMode,
            obs_mode=ObservationMode.State,
        )
        with pytest.raises(DemoNotFoundError):
            demo_store.get_demos(metadata, always_decimate=False)

    def test_exception_thrown_if_too_many_demos_requested(self, mock_gs_bucket):
        demo_to_store = _generate_simple_demo()
        demo_store = DemoStore.google_cloud()
        demo_store.upload_demo(demo_to_store)
        with pytest.raises(TooManyDemosRequestedError):
            demo_store.get_demos(
                demo_to_store.metadata, amount=1000, always_decimate=False
            )


def _generate_simple_demo():
    """Record a short ``ReachTarget`` demo (joint-position control, seed 42).

    Returns:
        The demo produced by ``TestDemoStore.record_demo`` with its default length.
    """
    demo_recorder = DemoRecorder()
    reach_env: BiGymEnv = ReachTarget(action_mode=JointPositionActionMode())
    TestDemoStore.record_demo(reach_env, demo_recorder, seed=42)
    simple_demo = demo_recorder.demo
    return simple_demo
